import sys
import os
# Put the project root (the parent of this file's directory) on sys.path;
# presumably this is what lets the `Server.*` imports below resolve when this
# page is launched directly rather than from the project root.
curPath = os.path.abspath(os.path.dirname(__file__))
rootPath = os.path.dirname(curPath)
sys.path.append(rootPath)


import streamlit as st
from Server.config import *
import time
from Server.controller import GPTgetSentence
from PIL import Image
# Picture shown by app() when the prompt mentions Huterox.
# NOTE(review): the directory is spelled "imgae" — confirm it matches the
# folder on disk. PIL's Image.open is lazy and keeps the file handle open
# until the pixel data is loaded.
image = Image.open(f"{curPath}/imgae/huterox.jpg")
# Conversation history handed to GPTgetSentence by app().
# NOTE(review): module-level state under Streamlit is either re-created on every
# rerun (main script) or shared by all sessions (imported page); per-session
# state belongs in st.session_state.
history = []

def app():
    """Render the GPT text-generation page.

    The sidebar exposes the sampling parameters: output length, top_k, top_p
    and temperature. The available models come from
    ``flow_get_model_name_list()``. When it reports none, a notice is shown
    instead of the generation form.

    Selecting "对话模型" (the dialogue/chat model) also exposes a
    context-memory length and passes ``True`` as the chat flag to
    ``GPTgetSentence``. Every other model is called with a history length of 5
    and the flag ``False``.

    Clicking the button runs the generation and shows the result together with
    the elapsed time. A prompt mentioning Huterox shows a picture instead.

    Chat history is kept in ``st.session_state`` so it survives Streamlit
    reruns and is private to each browser session.
    """
    st.markdown(
        """
        ## GPT RUN BY Huterox
        """
    )

    st.sidebar.subheader("参数配置")
    generator_number = st.sidebar.number_input("文字个数", min_value=0, max_value=512, value=128)
    top_k = st.sidebar.slider("top_k", min_value=1, max_value=20, value=10, step=1)
    top_p = st.sidebar.slider("top_p", min_value=0.8, max_value=1.0, value=0.9, step=0.02)
    temperature = st.sidebar.slider("temperature", min_value=0.1, max_value=5.0, value=1.0, step=0.1)

    # model_name holds the selectbox labels. model_list is indexed in parallel
    # with it (see the .index() lookup below) to get what GPTgetSentence receives.
    model_name, model_list = flow_get_model_name_list()
    if not model_name:
        st.markdown(
            """
            `项目目录GPT2/model/norm_model下未检测到模型`
            """
        )
        return

    choose = st.selectbox(
        "选择GPT模型",
        model_name,
        index=0
    )
    is_chat = choose == "对话模型"
    # Non-chat models always get a fixed history length of 5; only the chat
    # model lets the user tune it from the sidebar.
    max_history_len = 5
    if is_chat:
        max_history_len = st.sidebar.number_input("聊天上下文关联记忆", min_value=1, max_value=24, value=5)
    user_input = st.text_area("请输入文本", max_chars=512)

    if not st.button("点击生成结果"):
        return

    tips = st.empty()
    tips.text("正在努力生成中，第一次加载模型运行较慢哟~")
    if ("Huterox" in user_input) or ("huterox" in user_input):
        # Easter egg: show the bundled picture instead of running the model.
        tips.text("这是毫无疑问的，无论如何，Huterox is awesome!!!")
        st.image(image, caption='Huterox is awesome!!!', width=350)
        return

    # Keep the conversation per browser session. A module-level list is either
    # re-created on every Streamlit rerun (when this file is the main script),
    # which loses the chat memory, or shared by all users (when imported as a
    # page). Presumably GPTgetSentence appends to this list in place — confirm.
    if "gpt_chat_history" not in st.session_state:
        st.session_state["gpt_chat_history"] = []
    chat_history = st.session_state["gpt_chat_history"]

    start = time.perf_counter()
    # Positional arguments exactly as before. The meaning of the literal 100
    # is defined by GPTgetSentence and is not visible here.
    result = GPTgetSentence(
        user_input, temperature, top_k, top_p, generator_number,
        chat_history, max_history_len, 100, is_chat,
        model_list[model_name.index(choose)],
    )
    # Measure generation only, not the rendering of the result widget.
    paytime = time.perf_counter() - start

    st.text_area("生成结果", value=result)
    tips.text("耗时：" + str(paytime) + "s")


if __name__ == "__main__":
    # Render the page when this file is executed as the main script
    # (e.g. via `streamlit run`).
    app()
